#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/7/27 14:07
# @Author  : Deyu.Tian
# @Site    : ChangGuang Satellite
# @File    : Potential_changes.py
# @Software: PyCharm Community Edition

import Image
import Image_process
import numpy as np
import sys
import os
import Config
import Visulazation as visul


def getPCC(bandpath, yuzhi, dth_path, pcc_path, min_area):
    """
    Build the potential-change-component (PCC) raster from a difference map.

    The difference band is thresholded with ``Yuzhi_band`` (result saved to
    ``dth_path``), the binary map is labelled with
    ``Image_process.conn_4_neibour``, every label whose pixel count is below
    ``min_area`` is reset to 0, and the label raster is written to
    ``pcc_path``.

    :param bandpath: path of the difference-map raster to threshold
    :param yuzhi: threshold ("yuzhi") applied to the difference map
    :param dth_path: output path of the thresholded (Dth) raster
    :param pcc_path: output path of the filtered label (PCC) raster
    :param min_area: labels covering fewer pixels than this are set to 0
    :return: None
    """
    dth, coeffs = Yuzhi_band(bandpath, yuzhi, dth_path)
    print(dth.shape)
    print("Dth saved as tiff")
    labels = Image_process.conn_4_neibour(dth)
    unique, counts = np.unique(labels, return_counts=True)
    print(unique.shape, counts.shape)
    # Select by the actual label values, not by their position in ``unique``:
    # label ids are not guaranteed to be the contiguous range 0..n-1.
    # One mask also replaces a full-image scan per label.
    small_labels = unique[counts < min_area]
    labels[np.isin(labels, small_labels)] = 0
    Image.array2rasterUTM(pcc_path, coeffs, labels)


def Yuzhi_band(bandpath, yuzhi, outpath):
    """
    Binarise a single raster with a threshold and save the result.

    Pixels strictly greater than ``yuzhi`` become 1000; every other pixel
    (including pixels equal to ``yuzhi`` and NaN) becomes 0. The array is
    modified in place, so its dtype is preserved.

    :param bandpath: input raster path (read with ``Image.img2array``)
    :param yuzhi: threshold value
    :param outpath: output path, written with ``Image.array2rasterUTM``
    :return: (binarised array, metadata from ``Image.read_tif_metadata``)
    """
    # Image_process.draw_hist(bandpath)
    arr = Image.img2array(bandpath)
    coeffs = Image.read_tif_metadata(bandpath)
    # Compute the mask once, before writing. Writing 1000 first and then
    # testing ``arr < yuzhi`` wipes the foreground whenever yuzhi > 1000.
    # A strict ``<`` also left pixels equal to the threshold unchanged.
    above = arr > yuzhi
    arr[above] = 1000
    arr[~above] = 0
    Image.array2rasterUTM(outpath, coeffs, arr)
    return arr, coeffs


def normalize_arr(inarr):
    """
    Min-max normalise an array to the range [0, 1].

    :param inarr: array-like of numbers
    :return: float array of the same shape. The minimum maps to 0 and the
             maximum to 1. A constant array maps to all zeros; an empty array
             is returned empty (as float).
    """
    # Work in float so integer input is not floor-divided (Python 2 / numpy).
    arr = np.asarray(inarr, dtype=float)
    if arr.size == 0:
        return arr
    lo, hi = np.min(arr), np.max(arr)
    span = hi - lo
    if span == 0:
        # Constant input: avoid 0/0 -> NaN.
        return np.zeros_like(arr)
    # Shift by the minimum so the result spans [0, 1].
    return (arr - lo) / span


def diff_map_pixel(imsatpath, imuavpath):
    """
    Simple pixel-by-pixel difference map.

    Pro: fast. Con: cannot suppress the salt-and-pepper noise and ghosting
    caused by registration errors between the two images.

    Each output pixel is the Euclidean norm of the difference between the UAV
    and satellite descriptors (from ``descriptor``) at the same (row, col).

    :param imsatpath: path of the satellite image
    :param imuavpath: path of the UAV image
    :return: float32 difference map with the (cropped) UAV image's height and
             width. It is also saved to Config.data as "diffmap.npy" and
             "diffmap.tif" (georeferenced with the satellite image metadata).
    """
    imsatarr, imuavarr = np.moveaxis(Image.img2array(imsatpath), 0, 2), np.moveaxis(Image.img2array(imuavpath), 0, 2)
    # NOTE(review): drops the first UAV column unconditionally, presumably to
    # align this dataset's UAV grid with the satellite grid; should be a
    # shape check.
    imuavarr = imuavarr[:, 1:, :]
    print(imsatarr.shape, imuavarr.shape)
    imsat_desc, imuav_desc = descriptor(imsatarr), descriptor(imuavarr)
    print("features of IMUAV and IMSAT have generated successed, "
          "now compute difference maps with matrix ways!")
    n_rows, n_cols = imuav_desc.shape[0], imuav_desc.shape[1]
    diffarr = np.zeros((imuavarr.shape[0], imuavarr.shape[1]), dtype='f')
    # One vectorised norm per row instead of one Python-level call per pixel.
    # Processing a row at a time keeps the temporary float buffer small.
    for i in range(n_rows):
        diffarr[i] = np.linalg.norm(imuav_desc[i] - imsat_desc[i, :n_cols], axis=1)
    np.save(os.path.join(Config.data, "diffmap.npy"), diffarr)
    imggt = Image.read_tif_metadata(imsatpath)
    Image.array2rasterUTM(os.path.join(Config.data, "diffmap.tif"), imggt, diffarr)
    return diffarr


def diff_map_matrix(imsatpath, imuavpath, w=0):
    """
    Difference map with a w x w neighbourhood search, using array operations
    to speed up the search.

    For every UAV pixel whose w x w window fits inside the image, the output
    is the minimum Euclidean distance between that pixel's descriptor and the
    satellite descriptors in the w x w window centred on the same position.
    Border pixels (closer than (w-1)/2 to an edge) stay 0.

    :param imsatpath: path of the satellite image
    :param imuavpath: path of the UAV image
    :param w: side length of the search window; must be a positive odd number
    :return: float32 difference map. It is also saved to Config.data as
             "diffmap_w<w>.npy" and "diffmap_w<w>.tif".
    """
    if w < 1 or w % 2 == 0:
        print("use this function like below: \n"
              "diff_map(imsatpath, imuavpath, w)\n need param w of search box,"
              " its must be odd like 1, 3, 5, 7, 9,.etc.")
        sys.exit(1)
    imsatarr, imuavarr = np.moveaxis(Image.img2array(imsatpath), 0, 2), np.moveaxis(Image.img2array(imuavpath), 0, 2)

    # NOTE(review): drops the first UAV column unconditionally; a shape check
    # should be added here.
    imuavarr = imuavarr[:, 1:, :]

    print(imsatarr.shape, imuavarr.shape)
    imsat_desc, imuav_desc = descriptor(imsatarr), descriptor(imuavarr)
    print("features of IMUAV and IMSAT have generated successed, "
          "now compute difference maps with matrix ways!")
    diffarr = np.zeros((imuavarr.shape[0], imuavarr.shape[1]), dtype='f')
    # Integer half-width: ``(w-1)/2`` is a float on Python 3 and breaks range().
    r = int((w - 1) // 2)
    n_rows, n_cols = imuav_desc.shape[0], imuav_desc.shape[1]
    for i in range(r, n_rows - r):
        for j in range(r, n_cols - r):
            window = imsat_desc[i - r:i + r + 1, j - r:j + r + 1]
            # Flatten the window to (w*w, n_features) for one vectorised norm.
            window_2d = window.reshape(-1, window.shape[2])
            u = imuav_desc[i, j].reshape(1, -1)
            diffarr[i, j] = np.min(np.linalg.norm(window_2d - u, axis=1))
    # Name outputs after the actual window size (was hard-coded to w9).
    out_stem = "diffmap_w{}".format(w)
    np.save(os.path.join(Config.data, out_stem + ".npy"), diffarr)
    imggt = Image.read_tif_metadata(imsatpath)
    Image.array2rasterUTM(os.path.join(Config.data, out_stem + ".tif"), imggt, diffarr)
    return diffarr


def diff_map(imsatpath, imuavpath, w=0):
    """
    Original difference-map implementation, abandoned by the author and kept
    for reference.

    For every UAV pixel (i, j), the output is the minimum Euclidean distance
    between its descriptor and the satellite descriptors in the w x w window
    centred on (i, j). The window is clipped to the satellite image bounds.

    :param imsatpath: path of the satellite image
    :param imuavpath: path of the UAV image
    :param w: side length of the search window in the satellite image; must be
              a positive odd number
    :return: float32 difference map, also saved to Config.data as "diffmap.npy"
    """
    if w < 1 or w % 2 == 0:
        print("use this function like below: \n"
              "diff_map(imsatpath, imuavpath, w)\n need param w of search box,"
              " its must be odd like 1, 3, 5, 7, 9,.etc.")
        sys.exit(1)
    imsatarr, imuavarr = np.moveaxis(Image.img2array(imsatpath), 0, 2), np.moveaxis(Image.img2array(imuavpath), 0, 2)
    print(imsatarr.shape, imuavarr.shape)
    imsat_desc, imuav_desc = descriptor(imsatarr), descriptor(imuavarr)
    print("descriptors of IMUAV and IMSAT have generated successed, now compute difference maps!")
    diffarr = np.zeros((imuavarr.shape[0], imuavarr.shape[1]), dtype='f')
    r = int((w - 1) // 2)
    sat_rows, sat_cols = imsat_desc.shape[0], imsat_desc.shape[1]
    for i in range(imuav_desc.shape[0]):
        for j in range(imuav_desc.shape[1]):
            # Clip the window to the satellite image:
            #   - index 0 is valid (the old check used ``> 0``);
            #   - bounds come from the array being indexed;
            #   - indices are (row, col), not (col, row).
            ymin, ymax = max(i - r, 0), min(i + r + 1, sat_rows)
            xmin, xmax = max(j - r, 0), min(j + r + 1, sat_cols)
            window = imsat_desc[ymin:ymax, xmin:xmax]
            window_2d = window.reshape(-1, window.shape[2])
            diffarr[i, j] = np.min(np.linalg.norm(window_2d - imuav_desc[i, j], axis=1))
    np.save(os.path.join(Config.data, "diffmap.npy"), diffarr)
    return diffarr


def descriptor(rgbarr):
    """
    Build a 36-layer per-pixel feature stack from an RGB array.

    Layers 0-8, 9-17 and 18-26 hold the 9-neighbour layers (from
    ``Image_process.band_9_neibour_layers_ignoreedge``) of bands 0, 1 and 2.
    Layers 27-35 hold the 9-neighbour gradient layers of the luminance image
    (see ``getIG``).

    :param rgbarr: array indexed as [row, col, band] with at least 3 bands
    :return: array of shape (rows, cols, 36) with dtype 'i' (C int)
    """
    n_rows, n_cols = rgbarr.shape[0], rgbarr.shape[1]
    features = np.zeros((n_rows, n_cols, 36), dtype='i')
    print(features.shape)

    # Nine neighbour layers per colour band, stored band after band.
    for band in range(3):
        first = 9 * band
        features[:, :, first:first + 9] = Image_process.band_9_neibour_layers_ignoreedge(rgbarr[:, :, band])

    # Luma with the 0.299 / 0.587 / 0.114 weights.
    luma = rgbarr[:, :, 0] * 0.299 + rgbarr[:, :, 1] * 0.587 + rgbarr[:, :, 2] * 0.114
    # NOTE(review): visualisation side effect inside feature extraction; the
    # equivalent calls in the diff_map_* functions are commented out.
    visul.visul_arr_rgb(luma)
    features[:, :, 27:36] = getIG(np.uint8(luma))  # if( sdepth == CV_16S && ddepth == CV_32F)
    print("max of IG is ", np.max(features[:, :, 27:36]))

    return features


def getIG(lumimgarr):
    """
    Build the gradient (IG) layers over each pixel's 9-neighbourhood.

    The luminance array is passed through ``Image_process.sobel`` with
    kernel_size=5. The gradient image is then expanded into its 9-neighbour
    layers with ``Image_process.band_9_neibour_layers_ignoreedge``.

    :param lumimgarr: luminance (Y) array
    :return: stacked 9-neighbour layers of the gradient image
    """
    grad_img = Image_process.sobel(lumimgarr, kernel_size=5)
    ig_layers = Image_process.band_9_neibour_layers_ignoreedge(grad_img)
    return ig_layers


if __name__ == '__main__':
    # Script entry point, used for manual testing.
    # diff_map_matrix(Config.ImSat, Config.ImUAV, 9)
    # diff_map_pixel(Config.ImSat, Config.ImUAV)

    # Build paths from components instead of embedding Windows "\\"
    # separators, so the script also works on POSIX systems.
    dth_dir = os.path.join(Config.data, "Dth")
    bandpath = os.path.join(dth_dir, "diffmap_w21.tif")
    yuzhi = 450.18
    dth_path = os.path.join(dth_dir, "Dth_{}.tif".format(yuzhi))
    pcc_path = os.path.join(dth_dir, "PCC.tif")
    getPCC(bandpath, yuzhi, dth_path, pcc_path, min_area=100)